# -*- coding: utf-8 -*-
"""
Created on Sun Nov 24 19:59:16 2019
读取数据，并进行svm训练
@author: EDMOND
"""

import numpy as np
from sklearn import svm
from scipy import signal
import pickle


paths = ['left', 'push&pull', 'right', 'zhua']
# 获取数据，并保存在4类矩阵中
left_sample = np.zeros((600, 448))
right_sample = np.zeros((600, 448))
push_pull_sample = np.zeros((600, 448))
catch_sample = np.zeros((600, 448))


b, a = signal.butter(8, [0.05,0.4], 'bandpass')

# 读取根目录下4种信号 把信号滤波 STFT ，将信号STFT的值放到四类样本数组中
for i in range(600):
    for path in paths:
        if path == 'left':
            left_data = np.loadtxt(path + '/' + str(i) + '.txt');
            data_filted = signal.filtfilt(b, a, left_data)  # data为要过滤的信号
            f, t, Zxx = signal.stft(data_filted, nperseg=30)
            Zxx=abs(Zxx)
            left_sample[i, :] = Zxx.reshape(1,448)
        elif path == 'right':
            right_data = np.loadtxt(path + '/' + str(i) + '.txt')
            data_filted = signal.filtfilt(b, a, right_data)  # data为要过滤的信号
            f, t, Zxx = signal.stft(data_filted, nperseg=30)
            Zxx=abs(Zxx)
            right_sample[i, :] = Zxx.reshape(1,448)


        elif path == 'push&pull':
            push_data = np.loadtxt(path + '/' + str(i) + '.txt')
            data_filted = signal.filtfilt(b, a, push_data)  # data为要过滤的信号
            f, t, Zxx = signal.stft(data_filted, nperseg=30)
            Zxx=abs(Zxx)
            push_pull_sample[i, :] = Zxx.reshape(1,448)

        elif path == 'zhua':
            zhua_data = np.loadtxt(path + '/' + str(i) + '.txt')
            data_filted = signal.filtfilt(b, a, zhua_data)  # data为要过滤的信号
            f, t, Zxx = signal.stft(data_filted, nperseg=30)
            Zxx=abs(Zxx)
            catch_sample[i, :] = Zxx.reshape(1,448)

# 将训练集合并，作为数据矩阵
sample = np.vstack((left_sample, right_sample, push_pull_sample, catch_sample))
#标签
label = np.vstack((np.ones((600, 1)), 2 * np.ones((600, 1)), 3 * np.ones((600, 1)), 4 * np.ones((600, 1))))


# 支持向量机，ovo的形式(每两类组成一个向量机)，多项式形式核函数，多项式度为2，其他取默认值
CLF = svm.SVC(decision_function_shape='ovr', max_iter=-1, cache_size=12000, kernel='poly', degree=3,C=10,coef0=10,gamma='auto');

# 划分训练集测试集
index = np.array([range(0, 100), range(600, 700), range(1200, 1300), range(1800, 1900)]).reshape(400, )
test_sample = sample[index, :]#测试集 数据
test_label = label[index, :]#测试集 标签
train_sample = np.delete(sample, index, axis=0) #训练集 数据
train_label = np.delete(label, index, axis=0)#训练集 标签
# 训练svm
CLF.fit(train_sample, train_label.reshape((2000,)))
with open('C:/Users/姚奕成/OneDrive/桌面/传感技术作业/data/clf.pickle','wb')as f: #python路径要用反斜杠
    pickle.dump(CLF,f) #将模型dump进f里面

# 使用测试集做预测
pred = CLF.predict(test_sample)
# 和测试集标签比较，得出准确率
accr = np.mean(pred == test_label.reshape(400, ))
print(accr )







